package com.example.service;

import com.example.model.AuditLog;
import com.example.repository.AuditLogRepository;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.scheduling.annotation.Async;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import jakarta.servlet.http.HttpServletRequest;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

/**
 * Audit log service.
 *
 * <p>Builds {@link AuditLog} entries for user operations, logins/logouts, sensitive operations and
 * file operations, enriches them with details of the HTTP request bound to the current thread (if
 * any), and persists them through {@link AuditLogRepository}. Also offers query, statistics and
 * retention-cleanup helpers.
 *
 * <p>The {@code log*} recording methods are best-effort: failures are logged and never propagated
 * to the caller.
 *
 * <p>NOTE(review): the recording methods call {@link #logAsync(AuditLog)} on {@code this}. With
 * Spring's default proxy-based AOP, such self-invocations bypass the proxy, so {@code @Async} and
 * {@code @Transactional} on {@code logAsync} are not applied on those paths and the save runs
 * synchronously on the caller's thread. Moving the asynchronous writer into a separate bean would
 * make it truly asynchronous — confirm intent.
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class AuditLogService {

    /** Substitute for a missing start time in range queries (no audit entry predates it). */
    private static final LocalDateTime QUERY_LOWER_BOUND = LocalDateTime.of(1970, 1, 1, 0, 0);

    /**
     * Substitute for a missing end time in range queries. Kept inside the DATETIME range of common
     * SQL databases (maximum year 9999) rather than using {@link LocalDateTime#MAX}.
     */
    private static final LocalDateTime QUERY_UPPER_BOUND = LocalDateTime.of(9999, 12, 31, 0, 0);

    /** Headers consulted, in order, for the client address before falling back to the socket peer. */
    private static final List<String> CLIENT_IP_HEADERS =
            List.of("X-Forwarded-For", "Proxy-Client-IP", "WL-Proxy-Client-IP");

    /** Normalized request-parameter names that are dropped when they match exactly. */
    private static final Set<String> SENSITIVE_PARAM_NAMES = Set.of("key");

    /** Normalized request-parameter names containing any of these fragments are dropped. */
    private static final List<String> SENSITIVE_PARAM_MARKERS = List.of(
            "password", "passwd", "pwd", "token", "secret", "credential",
            "apikey", "accesskey", "privatekey", "authorization");

    private final AuditLogRepository auditLogRepository;
    private final ObjectMapper objectMapper;

    /**
     * Persists an audit log entry, asynchronously when invoked through the Spring proxy.
     *
     * <p>Persistence failures are logged and swallowed.
     *
     * @param auditLog the entry to save
     */
    @Async
    @Transactional
    public void logAsync(AuditLog auditLog) {
        try {
            auditLogRepository.save(auditLog);
        } catch (Exception e) {
            log.error("保存审计日志失败: {}", e.getMessage(), e);
        }
    }

    /**
     * Persists an audit log entry synchronously.
     *
     * <p>Persistence failures are logged and swallowed.
     *
     * <p>NOTE(review): if this runs inside a caller's transaction and the repository call marks that
     * transaction rollback-only, catching the exception here does not prevent the caller's commit
     * from failing. Consider {@code Propagation.REQUIRES_NEW} if audit writes must be independent of
     * the caller's transaction.
     *
     * @param auditLog the entry to save
     */
    @Transactional
    public void log(AuditLog auditLog) {
        try {
            auditLogRepository.save(auditLog);
        } catch (Exception e) {
            log.error("保存审计日志失败: {}", e.getMessage(), e);
        }
    }

    /**
     * Records a general operation performed by the currently authenticated user.
     *
     * <p>Does nothing when there is no authenticated user in the security context.
     *
     * @param operationType kind of operation
     * @param module        functional module the operation belongs to
     * @param description   human-readable description
     */
    public void logOperation(AuditLog.OperationType operationType, String module, String description) {
        try {
            Authentication auth = currentAuthentication();
            if (auth == null) {
                return;
            }

            AuditLog auditLog = createForPrincipal(auth, operationType, module, description);
            submit(auditLog);
        } catch (Exception e) {
            log.error("记录操作日志失败: {}", e.getMessage(), e);
        }
    }

    /**
     * Records a login attempt.
     *
     * <p>A failed attempt carries {@code errorMessage} and is raised to {@code MEDIUM} risk.
     *
     * @param userId       id of the user attempting to log in
     * @param username     name of the user attempting to log in
     * @param success      whether the login succeeded
     * @param errorMessage failure reason; only used when {@code success} is {@code false}
     */
    public void logLogin(String userId, String username, boolean success, String errorMessage) {
        try {
            AuditLog auditLog = AuditLog.create(userId, username, AuditLog.OperationType.LOGIN, "AUTH", "用户登录");

            if (!success) {
                auditLog.withError(errorMessage);
                auditLog.withRiskLevel(AuditLog.RiskLevel.MEDIUM);
            }

            submit(auditLog);
        } catch (Exception e) {
            log.error("记录登录日志失败: {}", e.getMessage(), e);
        }
    }

    /**
     * Records a logout.
     *
     * @param userId   id of the user logging out
     * @param username name of the user logging out
     */
    public void logLogout(String userId, String username) {
        try {
            AuditLog auditLog = AuditLog.create(userId, username, AuditLog.OperationType.LOGOUT, "AUTH", "用户登出");
            submit(auditLog);
        } catch (Exception e) {
            log.error("记录登出日志失败: {}", e.getMessage(), e);
        }
    }

    /**
     * Records a sensitive operation by the current user at {@code HIGH} risk.
     *
     * <p>Does nothing when there is no authenticated user. If {@code data} is non-null it is
     * serialized to JSON as the entry's extra info. A serialization failure is logged, and the
     * entry is still recorded.
     *
     * <p>NOTE(review): {@code data} is serialized verbatim (it is not passed through the
     * sensitive-parameter filter), so callers must not pass secrets in it.
     *
     * @param operationType kind of operation
     * @param module        functional module the operation belongs to
     * @param description   human-readable description
     * @param data          optional payload to attach; may be {@code null}
     */
    public void logSensitiveOperation(AuditLog.OperationType operationType, String module,
                                    String description, Object data) {
        try {
            Authentication auth = currentAuthentication();
            if (auth == null) {
                return;
            }

            AuditLog auditLog = createForPrincipal(auth, operationType, module, description);
            auditLog.withRiskLevel(AuditLog.RiskLevel.HIGH);

            if (data != null) {
                try {
                    auditLog.withExtraInfo(objectMapper.writeValueAsString(data));
                } catch (Exception e) {
                    log.warn("序列化操作数据失败: {}", e.getMessage());
                }
            }

            submit(auditLog);
        } catch (Exception e) {
            log.error("记录敏感操作日志失败: {}", e.getMessage(), e);
        }
    }

    /**
     * Records a file operation by the current user in the {@code FILE} module.
     *
     * <p>Does nothing when there is no authenticated user. The extra info is a JSON object with
     * {@code filename} and {@code fileSize}.
     *
     * @param operationType kind of file operation
     * @param filename      name of the file involved
     * @param fileSize      size of the file in bytes
     */
    public void logFileOperation(AuditLog.OperationType operationType, String filename, long fileSize) {
        try {
            Authentication auth = currentAuthentication();
            if (auth == null) {
                return;
            }

            String description = String.format("文件操作: %s, 大小: %d bytes", filename, fileSize);
            AuditLog auditLog = createForPrincipal(auth, operationType, "FILE", description);

            // Serialize with Jackson so quotes/backslashes in the filename are escaped and cannot
            // break (or inject into) the JSON document; insertion order keeps filename first.
            Map<String, Object> fileInfo = new LinkedHashMap<>();
            fileInfo.put("filename", filename);
            fileInfo.put("fileSize", fileSize);
            auditLog.withExtraInfo(objectMapper.writeValueAsString(fileInfo));

            submit(auditLog);
        } catch (Exception e) {
            log.error("记录文件操作日志失败: {}", e.getMessage(), e);
        }
    }

    /**
     * Returns the current authentication when it is present and authenticated, otherwise {@code null}.
     *
     * <p>NOTE(review): Spring Security's anonymous token reports {@code isAuthenticated() == true},
     * so anonymous requests pass this check — confirm that recording them is intended.
     */
    private static Authentication currentAuthentication() {
        Authentication auth = SecurityContextHolder.getContext().getAuthentication();
        return (auth != null && auth.isAuthenticated()) ? auth : null;
    }

    /**
     * Creates an audit entry attributed to the given principal.
     *
     * <p>NOTE(review): both userId and username are taken from {@link Authentication#getName()}. If
     * the principal exposes a distinct user id, it should be used for userId — confirm.
     */
    private static AuditLog createForPrincipal(Authentication auth, AuditLog.OperationType operationType,
                                               String module, String description) {
        String principalName = auth.getName();
        return AuditLog.create(principalName, principalName, operationType, module, description);
    }

    /** Attaches current-request details to the entry and hands it to {@link #logAsync(AuditLog)}. */
    private void submit(AuditLog auditLog) {
        // Request info must be captured here, on the request thread: RequestContextHolder is
        // thread-bound and would be empty on an async executor thread.
        setRequestInfo(auditLog);
        logAsync(auditLog);
    }

    /**
     * Copies method, URI, client IP, User-Agent and filtered parameters of the current servlet
     * request onto the entry.
     *
     * <p>Does nothing when no servlet request is bound to the thread. Failures are logged as
     * warnings and never propagated.
     */
    private void setRequestInfo(AuditLog auditLog) {
        try {
            Object attributes = RequestContextHolder.getRequestAttributes();
            if (!(attributes instanceof ServletRequestAttributes)) {
                return;
            }
            HttpServletRequest request = ((ServletRequestAttributes) attributes).getRequest();

            auditLog.setRequestMethod(request.getMethod());
            auditLog.setRequestUrl(request.getRequestURI());
            auditLog.setClientIp(getClientIp(request));
            auditLog.setUserAgent(request.getHeader("User-Agent"));

            // Record request parameters with credential-like entries removed.
            Map<String, String[]> paramMap = request.getParameterMap();
            if (!paramMap.isEmpty()) {
                try {
                    String params = objectMapper.writeValueAsString(filterSensitiveParams(paramMap));
                    auditLog.setRequestParams(params);
                } catch (Exception e) {
                    log.warn("序列化请求参数失败: {}", e.getMessage());
                }
            }
        } catch (Exception e) {
            log.warn("设置请求信息失败: {}", e.getMessage());
        }
    }

    /**
     * Returns a copy of {@code paramMap} without credential-like parameters.
     *
     * <p>Matching is case-insensitive and ignores {@code _} and {@code -}. For example,
     * {@code password}, {@code newPassword}, {@code access_token}, {@code API-Key} and {@code key}
     * are all removed. Iteration order of the input is preserved.
     */
    private Map<String, String[]> filterSensitiveParams(Map<String, String[]> paramMap) {
        Map<String, String[]> filteredMap = new LinkedHashMap<>();
        paramMap.forEach((name, values) -> {
            if (!isSensitiveParam(name)) {
                filteredMap.put(name, values);
            }
        });
        return filteredMap;
    }

    /** Decides whether a request-parameter name looks like it carries a credential. */
    private static boolean isSensitiveParam(String name) {
        if (name == null) {
            return false;
        }
        // Normalize "api_key", "API-Key" and "apiKey" to the same form before matching.
        String normalized = name.toLowerCase(Locale.ROOT).replace("_", "").replace("-", "");
        if (SENSITIVE_PARAM_NAMES.contains(normalized)) {
            return true;
        }
        return SENSITIVE_PARAM_MARKERS.stream().anyMatch(normalized::contains);
    }

    /**
     * Determines the client IP address.
     *
     * <p>Checks the proxy headers in {@link #CLIENT_IP_HEADERS} first, then the socket peer. For a
     * comma-separated list, the first entry is used. Returns {@code "unknown"} when nothing is
     * available.
     *
     * <p>NOTE(review): these headers are client-controllable unless a trusted proxy overwrites them;
     * the recorded IP can be spoofed. Consider Spring's forwarded-header handling configured for the
     * actual proxy chain.
     */
    private String getClientIp(HttpServletRequest request) {
        String ip = null;
        for (String header : CLIENT_IP_HEADERS) {
            String candidate = request.getHeader(header);
            if (isUsableIpHeader(candidate)) {
                ip = candidate;
                break;
            }
        }
        if (ip == null) {
            ip = request.getRemoteAddr();
        }

        // indexOf rather than split(",")[0]: split would throw on a value of just ",".
        if (ip != null) {
            int comma = ip.indexOf(',');
            if (comma >= 0) {
                ip = ip.substring(0, comma).trim();
            }
        }

        return (ip == null || ip.isBlank()) ? "unknown" : ip;
    }

    /** A proxy header value is usable when it is non-blank and not the placeholder "unknown". */
    private static boolean isUsableIpHeader(String value) {
        return value != null && !value.isBlank() && !"unknown".equalsIgnoreCase(value.trim());
    }

    /**
     * Pages through audit logs, newest first, optionally filtered by user, operation type and time.
     *
     * <p>Any {@code null} filter is ignored. A missing {@code startTime} or {@code endTime} makes that
     * side of the time range open-ended. When every filter is {@code null}, all logs are returned.
     *
     * @param userId        user id to match, or {@code null}
     * @param operationType operation type to match, or {@code null}
     * @param startTime     inclusive lower bound of operation time, or {@code null}
     * @param endTime       inclusive upper bound of operation time, or {@code null}
     * @param page          zero-based page index
     * @param size          page size
     * @return the requested page, sorted by {@code operationTime} descending
     */
    public Page<AuditLog> findAuditLogs(String userId, AuditLog.OperationType operationType,
                                       LocalDateTime startTime, LocalDateTime endTime,
                                       int page, int size) {
        Pageable pageable = PageRequest.of(page, size, Sort.by(Sort.Direction.DESC, "operationTime"));

        if (userId == null && operationType == null && startTime == null && endTime == null) {
            return auditLogRepository.findAll(pageable);
        }

        // Fill open time bounds so that a userId/operationType filter, or a single bound, is never
        // silently dropped (previously such combinations fell through to findAll).
        LocalDateTime from = startTime != null ? startTime : QUERY_LOWER_BOUND;
        LocalDateTime to = endTime != null ? endTime : QUERY_UPPER_BOUND;

        if (userId != null && operationType != null) {
            return auditLogRepository.findByUserIdAndOperationTypeAndOperationTimeBetween(
                userId, operationType, from, to, pageable);
        }
        if (userId != null) {
            return auditLogRepository.findByUserIdAndOperationTimeBetween(userId, from, to, pageable);
        }
        if (operationType != null) {
            return auditLogRepository.findByOperationTypeAndOperationTimeBetween(operationType, from, to, pageable);
        }
        return auditLogRepository.findByOperationTimeBetween(from, to, pageable);
    }

    /**
     * Returns operations with status {@code FAILED} whose operation time is within the given range.
     */
    public List<AuditLog> findFailedOperations(LocalDateTime startTime, LocalDateTime endTime) {
        return auditLogRepository.findByStatusAndOperationTimeBetween(
            AuditLog.OperationStatus.FAILED, startTime, endTime);
    }

    /**
     * Returns operations with risk level {@code HIGH} or {@code CRITICAL} whose operation time is
     * within the given range.
     */
    public List<AuditLog> findSensitiveOperations(LocalDateTime startTime, LocalDateTime endTime) {
        return auditLogRepository.findByRiskLevelInAndOperationTimeBetween(
            List.of(AuditLog.RiskLevel.HIGH, AuditLog.RiskLevel.CRITICAL), startTime, endTime);
    }

    /**
     * Counts operations per operation type within the given time range.
     *
     * <p>Assumes each repository row is {@code [OperationType, count]}.
     *
     * @return mutable map from operation type to count; types with no rows are absent
     */
    public Map<AuditLog.OperationType, Long> countOperationsByType(LocalDateTime startTime, LocalDateTime endTime) {
        List<Object[]> results = auditLogRepository.countByOperationTypeAndOperationTimeBetween(startTime, endTime);
        // HashMap (not EnumMap) so a row with a null operation type is still tolerated.
        Map<AuditLog.OperationType, Long> countMap = new HashMap<>();

        for (Object[] row : results) {
            AuditLog.OperationType type = (AuditLog.OperationType) row[0];
            // Number rather than Long: tolerates providers that return other integral types.
            long count = ((Number) row[1]).longValue();
            countMap.put(type, count);
        }

        return countMap;
    }

    /**
     * Deletes audit logs whose operation time is more than {@code retentionDays} days in the past.
     *
     * @param retentionDays number of days to keep; must not be negative
     * @return number of deleted rows, as reported by the repository
     * @throws IllegalArgumentException if {@code retentionDays} is negative (that would move the
     *                                  cutoff into the future and delete every audit record)
     */
    @Transactional
    public int cleanExpiredLogs(int retentionDays) {
        if (retentionDays < 0) {
            throw new IllegalArgumentException("retentionDays must be >= 0, got: " + retentionDays);
        }
        LocalDateTime cutoffTime = LocalDateTime.now().minusDays(retentionDays);
        return auditLogRepository.deleteByOperationTimeBefore(cutoffTime);
    }
}
